import numpy as np
from PDE2d import PDE2d
from ODE2d import ODE2d
from simulation_plots import make_plots
import tqdm
from scipy.special import xlogy
import matplotlib.pyplot as plt

# Discretisation parameters shared by every experiment below.
# dx, dy and N are passed straight to PDE2d (presumably grid spacings and grid
# size -- confirm against PDE2d); dt is the time step, each experiment runs
# int(T/dt) steps of it.
dx = 0.2
dy = 0.2
N  = 25
dt = 0.005


######################## option 1 ###############################################
def f1(x,z):
    '''Cost f1 at decision variable x and feature array z.

    With t = <x, z> (axis 0 of x contracted with the last axis of z), returns
    (1 - r)/2 where r = 1/(1 + exp(t)) rounded to 16 decimals. The result has
    the shape of z without its last axis.
    '''
    inner = np.tensordot(x, z, axes=[0, -1])
    rounded_inverse = np.around((1. + np.exp(inner)) ** -1, decimals=16)
    return (1. - rounded_inverse) / 2.


def grad_f1_x(x,z):
    '''Gradient of f1 with respect to x, evaluated at inputs x and z.

    With t = <x, z> (axis 0 of x contracted with the last axis of z) the
    gradient is z * exp(t) / (1 + exp(t))**2 / 2.

    x : 1-D array whose length equals the size of the last axis of z.
    z : array of feature vectors stored along the last axis; any number of
        leading axes and any real dtype are accepted. z is not modified.

    Returns an array with the same shape as z.
    '''
    z = np.asarray(z)
    inner = np.tensordot(x, z, axes=(0, -1))
    # exp(t)/(1+exp(t))**2 is even in t, so evaluate it with exp(-|t|): the
    # direct form turns into inf/inf = nan once exp(t) overflows.
    decay = np.exp(-np.abs(inner))
    weight = decay / (1. + decay) ** 2
    # Broadcast the per-point weight over every feature component instead of
    # scaling components 0 and 1 by hand.
    return z * weight[..., np.newaxis] / 2.


def f2(x,z):
    '''Cost f2 at decision variable x and feature array z.

    With t = <x, z> (axis 0 of x contracted with the last axis of z), returns
    r/2 where r = 1/(1 + exp(t)) rounded to 16 decimals. The result has the
    shape of z without its last axis.
    '''
    inner = np.tensordot(x, z, axes=[0, -1])
    rounded_inverse = np.around((1. + np.exp(inner)) ** -1, decimals=16)
    return rounded_inverse / 2.


def grad_f2_x(x,z):
    '''Gradient of f2 with respect to x, evaluated at inputs x and z.

    With t = <x, z> (axis 0 of x contracted with the last axis of z) the
    gradient is -z * exp(t) / (1 + exp(t))**2 / 2.

    x : 1-D array whose length equals the size of the last axis of z.
    z : array of feature vectors stored along the last axis; any number of
        leading axes and any real dtype are accepted. z is not modified.

    Returns an array with the same shape as z.
    '''
    z = np.asarray(z)
    inner = np.tensordot(x, z, axes=(0, -1))
    # exp(t)/(1+exp(t))**2 is even in t, so evaluate it with exp(-|t|): the
    # direct form turns into inf/inf = nan once exp(t) overflows.
    decay = np.exp(-np.abs(inner))
    weight = decay / (1. + decay) ** 2
    # Broadcast the per-point weight over every feature component instead of
    # scaling components 0 and 1 by hand.
    return -z * weight[..., np.newaxis] / 2.


########################## option 2 #####################################
def inside_exp(x,z):
    '''Argument of the exponential used by the option-2 cost functions.

    Returns (1 - x[0])*z[:,:,0] + x[0]*z[:,:,1] + x[1], i.e. a mix of the two
    feature components weighted by x[0], shifted by x[1].
    '''
    mix, shift = x[0], x[1]
    blended = (1 - mix) * z[:, :, 0] + mix * z[:, :, 1]
    return blended + shift


def f1_b(x,z):
    '''Option-2 cost f1 at x and z: 1/(1 + exp(u)) rounded to 16 decimals,
    where u = inside_exp(x, z).'''
    arg = inside_exp(x, z)
    return np.around((1. + np.exp(arg)) ** -1, decimals=16)


def grad_f1_x_b(x,z):
    '''Gradient of the option-2 cost f1_b with respect to x, evaluated at x and z.

    With u = inside_exp(x, z), f1_b = 1/(1 + exp(u)) and
    du/dx = (z[:,:,1] - z[:,:,0], 1), so the gradient is
    -exp(u)/(1 + exp(u))**2 * du/dx.

    Returns an array shaped like z[:,:,0] with a trailing axis of length 2.
    '''
    arg = inside_exp(x, z)
    # exp(u)/(1+exp(u))**2 is even in u, so evaluate it with exp(-|u|): the
    # direct form turns into 0*inf = nan once exp(u) overflows.
    decay = np.exp(-np.abs(arg))
    weight = decay / (1. + decay) ** 2
    first = z[:, :, 0]
    direction = np.stack([z[:, :, 1] - first, np.ones(np.shape(first))], axis=-1)
    return -direction * weight[..., np.newaxis]


def f2_b(x,z):
    '''Option-2 cost f2 at x and z: 1 - r, where r = 1/(1 + exp(u)) rounded to
    16 decimals and u = inside_exp(x, z).'''
    arg = inside_exp(x, z)
    rounded_inverse = np.around((1. + np.exp(arg)) ** -1, decimals=16)
    return 1 - rounded_inverse


def grad_f2_x_b(x,z):
    '''Gradient of the option-2 cost f2_b with respect to x, evaluated at x and z.

    With u = inside_exp(x, z), f2_b = 1 - 1/(1 + exp(u)) and
    du/dx = (z[:,:,1] - z[:,:,0], 1), so the gradient is
    exp(u)/(1 + exp(u))**2 * du/dx.

    Returns an array shaped like z[:,:,0] with a trailing axis of length 2.
    '''
    arg = inside_exp(x, z)
    # exp(u)/(1+exp(u))**2 is even in u, so evaluate it with exp(-|u|): the
    # direct form turns into 0*inf = nan once exp(u) overflows.
    decay = np.exp(-np.abs(arg))
    weight = decay / (1. + decay) ** 2
    first = z[:, :, 0]
    direction = np.stack([z[:, :, 1] - first, np.ones(np.shape(first))], axis=-1)
    return direction * weight[..., np.newaxis]


################## KL divergence, potential, and kernel #####################
def H1(rho_east_west,rho_bar,rho_tilde):
    '''KL divergence term.

    Returns rho_east_west * (log(rho_bar) - log(rho_tilde)) / 2, computed with
    xlogy so that entries where rho_east_west is 0 contribute 0.
    rho_tilde is the initial condition
    rho_bar is the current distribution
    '''
    current_term = xlogy(rho_east_west, rho_bar)
    reference_term = xlogy(rho_east_west, rho_tilde)
    return (current_term - reference_term) / 2.


def V1(x,z):
    '''Potential for option 1: the cost f1 evaluated at (x, z).'''
    potential = f1(x, z)
    return potential


def V2(x,z):
    '''Potential for option 2: the cost f1_b evaluated at (x, z).'''
    potential = f1_b(x, z)
    return potential


def W_no_kernel(Dx,Dy,dx,dy):
    '''Trivial interaction kernel: ignores all four arguments and returns 0.'''
    return 0.0


#################### initial distributions #################################
def initial_dist(z_x,z_y,mu_x,mu_y):
    '''Isotropic 2-D Gaussian bump centred at (mu_x, mu_y), evaluated at (z_x, z_y).

    The spread parameter sqrt(8) is used unsquared, both in the exponent
    denominator (2*spread) and in the normalisation 2*pi*spread.
    '''
    spread = np.sqrt(8.)
    scale = 2. * spread
    exponent = -(z_x - mu_x) ** 2 / scale - (z_y - mu_y) ** 2 / scale
    norm = 2 * np.pi * spread
    return np.exp(exponent) / norm


def double_gaussian(z_x,z_y,mu_x1,mu_y1,mu_x2,mu_y2):
    '''Equal-weight mixture of two isotropic 2-D Gaussian bumps, evaluated at (z_x, z_y).

    The bumps are centred at (mu_x1, mu_y1) and (mu_x2, mu_y2). The spread
    parameter sqrt(0.5) is used unsquared, both in the exponent denominator
    and in the normalisation 2*pi*spread.
    '''
    spread = np.sqrt(0.5)
    scale = 2. * spread
    norm = 2 * np.pi * spread

    def bump(centre_x, centre_y):
        # single Gaussian bump around the given centre
        return np.exp(-(z_x - centre_x) ** 2 / scale - (z_y - centre_y) ** 2 / scale) / norm

    return 0.5 * bump(mu_x1, mu_y1) + 0.5 * bump(mu_x2, mu_y2)


####################### select experiments ###############################
# Ids of the experiment sections below that should run; add 2 to the list to
# also run experiment 2.
experiments = [1] # ,2]

################ experiment 1: Gaussian initial conditions with set 1 f1, f2 #####################
if 1 in experiments:
    print("Experiment 1")
    x0            = np.array([-3.,3.])  # initial value of x handed to the ODE solver
    T             = 4.                  # final time
    nT            = int(T/dt)           # number of time steps (T/dt, truncated)
    mu            = 1.                  # offset of the two initial bumps from the origin
    x_conv_rate   = 1.e0 # convergence rate for x

    # PDE side: KL term H1, potential V1 (= f1) and no interaction kernel.
    pde = PDE2d(dx,dy,N,nT,H_prime_rho=H1,V=V1,W=W_no_kernel,save_data=False)
    # Two initial densities: one bump centred at (-mu,-mu), one at (mu,mu).
    pde.set_initial_distribution(lambda zx,zy: initial_dist(zx,zy,-mu,-mu),lambda zx,zy: initial_dist(zx,zy,mu,mu))
    # ODE side: x dynamics driven by the option-1 costs and their gradients.
    ode = ODE2d(pde.z_i,dt,nT,pde.g0,x0,f1,f2,grad_f1_x,grad_f2_x,save_data=True,x_speed=x_conv_rate)
    rho = pde.rho0
    # Alternate the two updates: x is updated from the current rho, then rho
    # is advanced one step using that new x.
    for t in tqdm.tqdm(range(0,nT)):
        x   = ode.update_x(rho,t)
        rho = pde.update_RK(x,t,dt)

    # NOTE(review): the string is presumably an output path prefix -- confirm in make_plots.
    make_plots(pde,ode,"plots/experiment2_1",make_gif=False)


################ experiment 2: bimodal Gaussian initial conditions with set 2 f1, f2 #####################
########## not in paper ###########
if 2 in experiments:
    print("Experiment 2")
    x0            = np.array([-3.,-3.])  # initial value of x handed to the ODE solver
    T             = 0.1                  # final time
    nT            = int(T/dt)            # number of time steps (T/dt, truncated)
    mu            = 1.                   # scale used to place the initial bumps
    x_conv_rate   = 1.e0                 # passed to ODE2d as x_speed

    # PDE side: KL term H1, potential V2 (= f1_b) and no interaction kernel.
    pde = PDE2d(dx,dy,N,nT,H_prime_rho=H1,V=V2,W=W_no_kernel,save_data=False)
    # Two initial densities: a two-bump mixture centred at (-mu,-mu/2) and
    # (mu/2,mu/2), and a single bump centred at (mu,mu).
    pde.set_initial_distribution(lambda zx,zy: double_gaussian(zx,zy,-mu,-mu/2,mu/2,mu/2),lambda zx,zy: initial_dist(zx,zy,mu,mu))
    # ODE side: x dynamics driven by the option-2 costs and their gradients.
    ode = ODE2d(pde.z_i,dt,nT,pde.g0,x0,f1_b,f2_b,grad_f1_x_b,grad_f2_x_b,save_data=True,x_speed=x_conv_rate)
    rho = pde.rho0
    # Alternate the two updates: x is updated from the current rho, then rho
    # is advanced one step using that new x.
    for t in tqdm.tqdm(range(0,nT)):
        x   = ode.update_x(rho,t)
        rho = pde.update_RK(x,t,dt)

    # NOTE(review): the string is presumably an output path prefix -- confirm in make_plots.
    make_plots(pde,ode,"plots/experiment2_2",make_gif=False)
